TBATS forecaster (multiple seasonality)

TBATS forecaster (multiple seasonality)

TBATS extends BATS with Trigonometric seasonality (Fourier-style terms), which is especially useful when one of the seasonal periods is large (e.g., yearly seasonality with daily data).

This notebook implements a practical TBATS-style forecaster with an interface like:

  • TBATS(use_box_cox=..., box_cox_bounds=..., seasonal_periods=..., seasonal_harmonics=..., use_arma_errors=...)

  • model = tbats.fit(y)

  • forecast = model.forecast(steps)

Implementation note: the original TBATS model is state-space / exponential smoothing based. Here we implement a TBATS-style forecaster using:

  • explicit trend + Fourier seasonal design matrices, and

  • ARMA errors estimated via statsmodels (SARIMAX with d=0).

Model sketch (math)

As with BATS, optionally apply Box–Cox and model \(x_t=g_\lambda(y_t)\).

TBATS uses a Fourier (trigonometric) seasonal representation with \(K\) harmonics for each seasonal period \(m\):

\[S_t^{(m)} = \sum_{k=1}^{K} \left(a_k\cos\left(\frac{2\pi k t}{m}\right) + b_k\sin\left(\frac{2\pi k t}{m}\right)\right).\]

This uses \(2K\) parameters per seasonality instead of the \(m-1\) seasonal dummies of BATS, which is far fewer when \(m\) is large (e.g. \(K=10\) harmonics need 20 coefficients, versus 364 dummies for \(m=365\)).

import warnings

import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import plotly.io as pio

from scipy import stats
import statsmodels.api as sm

# Hide UserWarning-category warnings for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

# The Plotly renderer can be overridden through the PLOTLY_RENDERER environment
# variable; without it the "notebook" renderer is used.
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
pio.templates.default = "plotly_white"

# Seeded generator so the simulated series further down is reproducible.
rng = np.random.default_rng(7)

# Record the library versions this notebook was run with.
import numpy, pandas, scipy, statsmodels, plotly
print("numpy:", numpy.__version__)
print("pandas:", pandas.__version__)
print("scipy:", scipy.__version__)
print("statsmodels:", statsmodels.__version__)
print("plotly:", plotly.__version__)
numpy: 1.26.2
pandas: 2.1.3
scipy: 1.15.0
statsmodels: 0.14.4
plotly: 6.5.2
class BoxCoxTransformer:
    """Optional shifted Box-Cox transform with a maximum-likelihood lambda.

    With ``use_box_cox=False`` the transformer is a pass-through: ``transform``
    and ``inverse_transform`` return a float copy of their input.

    With ``use_box_cox=True``, ``fit`` chooses a shift that makes the data
    strictly positive and estimates lambda by maximum likelihood; the estimate
    is then clamped into ``box_cox_bounds`` so the bounds are actually honoured.

    Attributes set by ``fit``:
        shift_:  constant added to the data before transforming.
        lambda_: fitted Box-Cox exponent, or ``None`` when Box-Cox is disabled
                 or ``fit`` has not been called.
    """

    def __init__(self, use_box_cox: bool, box_cox_bounds: tuple[float, float] = (0.0, 1.0)):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.shift_: float = 0.0
        self.lambda_: float | None = None

    def fit(self, y: np.ndarray) -> "BoxCoxTransformer":
        """Estimate the shift and lambda from ``y`` and return ``self``.

        Raises:
            ValueError: if Box-Cox is enabled and ``y`` is empty, contains
                NaN/inf, or is not strictly positive after shifting.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            self.shift_ = 0.0
            self.lambda_ = None
            return self

        # A NaN/inf would silently propagate into the shift and into lambda,
        # so reject bad input up front with a clear message.
        if y.size == 0 or not np.all(np.isfinite(y)):
            raise ValueError("Box-Cox requires non-empty, finite data.")

        min_y = float(np.min(y))
        # Shift so that the smallest observation maps to exactly 1.0 when the
        # data are not already strictly positive.
        self.shift_ = 0.0 if min_y > 0.0 else (1.0 - min_y)
        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lo, hi = self.box_cox_bounds
        if lo == hi:
            # Degenerate interval: the caller is pinning lambda to one value.
            self.lambda_ = float(lo)
            return self

        # ``brack`` is only the optimizer's starting bracket; the optimum it
        # returns may lie outside it, so clamp the estimate into the bounds.
        lmbda = float(stats.boxcox_normmax(y_pos, brack=(lo, hi), method="mle"))
        self.lambda_ = float(np.clip(lmbda, min(lo, hi), max(lo, hi)))
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        """Apply the fitted shift and Box-Cox transform to ``y``.

        Raises:
            RuntimeError: if Box-Cox is enabled but ``fit`` was not called.
            ValueError: if ``y + shift_`` is not strictly positive.
        """
        y = np.asarray(y, dtype=float)
        if not self.use_box_cox:
            return y.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before transform().")

        y_pos = y + self.shift_
        if np.any(y_pos <= 0.0):
            raise ValueError("Box-Cox requires strictly positive data (even after shift).")

        lmbda = float(self.lambda_)
        # lambda ~ 0 is the logarithmic limit of the Box-Cox family.
        if abs(lmbda) < 1e-10:
            return np.log(y_pos)
        return (np.power(y_pos, lmbda) - 1.0) / lmbda

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map transformed values ``x`` back to the original scale.

        No domain check is performed on ``x``; values are passed straight to
        the inverse formula and the fitted shift is subtracted afterwards.

        Raises:
            RuntimeError: if Box-Cox is enabled but ``fit`` was not called.
        """
        x = np.asarray(x, dtype=float)
        if not self.use_box_cox:
            return x.copy()
        if self.lambda_ is None:
            raise RuntimeError("Call fit() before inverse_transform().")

        lmbda = float(self.lambda_)
        if abs(lmbda) < 1e-10:
            y_pos = np.exp(x)
        else:
            y_pos = np.power(lmbda * x + 1.0, 1.0 / lmbda)
        return y_pos - self.shift_


def _acf(x: np.ndarray, max_lag: int) -> tuple[np.ndarray, np.ndarray]:
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    denom = float(np.dot(x, x))
    lags = np.arange(max_lag + 1)
    values = np.zeros(max_lag + 1)
    values[0] = 1.0
    if denom == 0.0:
        return lags, values
    for k in range(1, max_lag + 1):
        values[k] = float(np.dot(x[k:], x[:-k]) / denom)
    return lags, values


def trend_feature(t: np.ndarray, *, use_damped: bool, damped_phi: float) -> np.ndarray:
    """Trend regressor evaluated at the time indices ``t``.

    Returns ``t`` as floats for a plain linear trend, or the damped curve
    ``(1 - phi**t) / (1 - phi)`` when ``use_damped`` is set.

    Raises:
        ValueError: if ``use_damped`` is set and ``damped_phi`` is not strictly
            between 0 and 1 (``damped_phi`` is not checked otherwise).
    """
    steps = np.asarray(t, dtype=float)
    if use_damped:
        phi = float(damped_phi)
        # Chained comparison also rejects NaN.
        if 0.0 < phi < 1.0:
            return (1.0 - np.power(phi, steps)) / (1.0 - phi)
        raise ValueError("damped_phi must be in (0, 1)")
    return steps


def fourier_terms(t: np.ndarray, period: int, K: int) -> np.ndarray:
    """Build Fourier seasonal regressors for one seasonal period.

    For each harmonic ``k = 1 .. K`` the matrix holds ``cos(2*pi*k*t/period)``
    followed by ``sin(2*pi*k*t/period)``. ``K`` is capped at ``period // 2``.

    When ``period`` is even, the sine term of the harmonic ``k == period / 2``
    is ``sin(pi * t)``, which is zero (up to rounding noise) at every integer
    time index. That column is left out so the design matrix does not carry a
    numerically-zero regressor.

    Args:
        t: time indices (treated as floats; expected to lie on an integer grid).
        period: seasonal period; values ``<= 1`` produce no columns.
        K: requested number of harmonics; values ``<= 0`` produce no columns.

    Returns:
        Float array with ``t.size`` rows and up to ``2 * K`` columns.
    """
    t = np.asarray(t, dtype=float)
    period = int(period)
    K = int(K)
    if period <= 1 or K <= 0:
        return np.zeros((t.size, 0), dtype=float)

    K = min(K, period // 2)
    cols: list[np.ndarray] = []
    for k in range(1, K + 1):
        ang = 2.0 * np.pi * k * t / period
        cols.append(np.cos(ang))
        # Skip the degenerate sine at the highest harmonic of an even period.
        if 2 * k != period:
            cols.append(np.sin(ang))
    return np.column_stack(cols).astype(float)


def tbats_design_matrix(
    t: np.ndarray,
    *,
    use_trend: bool,
    use_damped_trend: bool,
    damped_trend_phi: float,
    seasonal_periods: list[int] | None,
    seasonal_harmonics: list[int] | None,
) -> np.ndarray:
    t = np.asarray(t, dtype=int)
    cols = [np.ones((t.size, 1), dtype=float)]
    if use_trend:
        cols.append(trend_feature(t.astype(float), use_damped=use_damped_trend, damped_phi=damped_trend_phi).reshape(-1, 1))

    if seasonal_periods:
        periods = [int(m) for m in seasonal_periods]
        if seasonal_harmonics is None:
            harmonics = [min(10, m // 2) for m in periods]
        else:
            harmonics = [int(k) for k in seasonal_harmonics]
            if len(harmonics) != len(periods):
                raise ValueError("seasonal_harmonics must match seasonal_periods length")

        for m, K in zip(periods, harmonics):
            cols.append(fourier_terms(t.astype(float), period=m, K=K))

    return np.concatenate(cols, axis=1)
class TBATSModel:
    """Fitted model handle: wraps the estimation results plus the settings
    needed to rebuild the design matrix for future time steps.

    ``results`` must expose ``nobs``, ``fittedvalues``, ``resid`` and
    ``get_forecast(steps=..., exog=...)``; ``transformer`` must expose
    ``inverse_transform``.
    """

    def __init__(
        self,
        *,
        results,
        transformer: BoxCoxTransformer,
        use_trend: bool,
        use_damped_trend: bool,
        damped_trend_phi: float,
        seasonal_periods: list[int] | None,
        seasonal_harmonics: list[int] | None,
        y_index,
    ):
        # Estimation output and the transform used on the training data.
        self.results = results
        self.transformer = transformer
        # Design-matrix settings, kept so forecasts use the same regressors.
        self.use_trend = use_trend
        self.use_damped_trend = use_damped_trend
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = seasonal_periods
        self.seasonal_harmonics = seasonal_harmonics
        # Index of the training series as supplied by the caller (may be None).
        self.y_index = y_index

    @property
    def n_obs(self) -> int:
        """Number of observations reported by the fitted results."""
        return int(self.results.nobs)

    def fitted_values(self) -> np.ndarray:
        """In-sample fitted values, mapped back through the inverse transform."""
        in_sample = np.asarray(self.results.fittedvalues, dtype=float)
        return self.transformer.inverse_transform(in_sample)

    def residuals(self) -> np.ndarray:
        """Residuals as reported by the results object (no inverse transform applied)."""
        return np.asarray(self.results.resid, dtype=float)

    def forecast(self, steps: int, *, alpha: float = 0.05) -> dict[str, np.ndarray]:
        """Forecast ``steps`` periods past the end of the training sample.

        Returns a dict with ``"mean"``, ``"lower"`` and ``"upper"`` arrays. The
        interval bounds come from ``conf_int(alpha=alpha)`` and, like the mean,
        are passed through the inverse transform.
        """
        horizon = int(steps)
        first = self.n_obs
        # Future time indices continue where the training indices stopped.
        exog_future = tbats_design_matrix(
            np.arange(first, first + horizon),
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
        )

        prediction = self.results.get_forecast(steps=horizon, exog=exog_future)
        center = np.asarray(prediction.predicted_mean, dtype=float)
        bounds = np.asarray(prediction.conf_int(alpha=alpha))

        back = self.transformer.inverse_transform
        return {
            "mean": back(center),
            "lower": back(bounds[:, 0]),
            "upper": back(bounds[:, 1]),
        }


class TBATS:
    """TBATS-style forecaster: regression on trend + Fourier seasonal terms
    with ARMA(p, q) errors, fitted through ``sm.tsa.SARIMAX`` with ``d = 0``.

    Args:
        use_box_cox: apply a Box-Cox transform to the data before fitting.
        box_cox_bounds: bounds handed to the Box-Cox transformer.
        use_trend: include a trend column in the design matrix.
        use_damped_trend: use the damped trend curve instead of a linear one.
        damped_trend_phi: damping factor for the damped trend.
        seasonal_periods: seasonal periods to model, or ``None`` for none.
        seasonal_harmonics: harmonics per period, or ``None`` for defaults.
        use_arma_errors: when False the error order is fixed at (0, 0).
        arma_order: explicit (p, q); pass ``None`` to select the order by AIC
            over ``0 .. max_arma_order`` for both p and q.
        max_arma_order: upper limit for p and q during AIC selection.
        show_warnings: print the chosen order/AIC after fitting and warn about
            ARMA candidates that had to be skipped during selection.
    """

    def __init__(
        self,
        *,
        use_box_cox: bool = False,
        box_cox_bounds: tuple[float, float] = (0.0, 1.0),
        use_trend: bool = True,
        use_damped_trend: bool = False,
        damped_trend_phi: float = 0.98,
        seasonal_periods: list[int] | None = None,
        seasonal_harmonics: list[int] | None = None,
        use_arma_errors: bool = True,
        arma_order: tuple[int, int] | None = (1, 1),
        max_arma_order: int = 1,
        show_warnings: bool = True,
    ):
        self.use_box_cox = bool(use_box_cox)
        self.box_cox_bounds = tuple(float(v) for v in box_cox_bounds)
        self.use_trend = bool(use_trend)
        self.use_damped_trend = bool(use_damped_trend)
        self.damped_trend_phi = float(damped_trend_phi)
        self.seasonal_periods = None if seasonal_periods is None else [int(m) for m in seasonal_periods]
        self.seasonal_harmonics = None if seasonal_harmonics is None else [int(k) for k in seasonal_harmonics]
        self.use_arma_errors = bool(use_arma_errors)
        self.arma_order = None if arma_order is None else (int(arma_order[0]), int(arma_order[1]))
        self.max_arma_order = int(max_arma_order)
        self.show_warnings = bool(show_warnings)

    def _fit_sarimax(self, y_x: np.ndarray, X: np.ndarray, order: tuple[int, int]) -> tuple[object, float]:
        """Fit SARIMAX(p, 0, q) with ``X`` as exogenous regressors.

        ``trend="n"`` because the intercept and trend already live in ``X``.
        Returns ``(results, aic)``.
        """
        p, q = order
        res = sm.tsa.SARIMAX(
            y_x,
            exog=X,
            order=(p, 0, q),
            trend="n",
            enforce_stationarity=True,
            enforce_invertibility=True,
        ).fit(disp=False, method="lbfgs", maxiter=300)
        return res, float(res.aic)

    def _select_arma_order(self, y_x: np.ndarray, X: np.ndarray) -> tuple[int, int]:
        """Return the (p, q) with the lowest AIC over the candidate grid.

        Candidates whose fit raises are skipped (best effort), but the failure
        is no longer lost: it is reported as a ``RuntimeWarning`` when
        ``show_warnings`` is set, and the most recent failure is attached as
        the cause of the final error.

        Raises:
            RuntimeError: if no candidate produced a finite-improving AIC.
        """
        candidates = [
            (p, q)
            for p in range(self.max_arma_order + 1)
            for q in range(self.max_arma_order + 1)
        ]

        best_order = (0, 0)
        best_aic = np.inf
        last_error: Exception | None = None

        for order in candidates:
            try:
                _, aic = self._fit_sarimax(y_x, X, order)
            except Exception as exc:
                # Keep going with the remaining candidates, but remember why
                # this one failed so the caller can diagnose a total failure.
                last_error = exc
                if self.show_warnings:
                    warnings.warn(
                        f"ARMA{order} candidate failed to fit and was skipped: {exc!r}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                continue
            if aic < best_aic:
                best_aic = aic
                best_order = order

        if best_aic == np.inf:
            raise RuntimeError("Failed to fit any ARMA(p,q) candidate.") from last_error
        return best_order

    def fit(self, y) -> TBATSModel:
        """Fit the model to ``y`` (a ``pd.Series`` or array-like) and return a ``TBATSModel``.

        Time indices are ``0 .. len(y) - 1``; for a Series the original index
        is kept on the returned model as ``y_index``.
        """
        if isinstance(y, pd.Series):
            y_index = y.index
            y_np = y.to_numpy(dtype=float)
        else:
            y_index = None
            y_np = np.asarray(y, dtype=float)

        t = np.arange(y_np.size)
        X = tbats_design_matrix(
            t,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
        )

        # The regression is estimated on the (optionally) transformed scale.
        transformer = BoxCoxTransformer(self.use_box_cox, box_cox_bounds=self.box_cox_bounds).fit(y_np)
        y_x = transformer.transform(y_np)

        # Order precedence: ARMA disabled -> (0, 0); explicit order -> as given;
        # otherwise search the grid by AIC.
        if not self.use_arma_errors:
            chosen_order = (0, 0)
        elif self.arma_order is not None:
            chosen_order = self.arma_order
        else:
            chosen_order = self._select_arma_order(y_x, X)

        res, aic = self._fit_sarimax(y_x, X, chosen_order)
        if self.show_warnings:
            print(f"Chosen ARMA(p,q) = {chosen_order}, AIC = {aic:.2f}")

        return TBATSModel(
            results=res,
            transformer=transformer,
            use_trend=self.use_trend,
            use_damped_trend=self.use_damped_trend,
            damped_trend_phi=self.damped_trend_phi,
            seasonal_periods=self.seasonal_periods,
            seasonal_harmonics=self.seasonal_harmonics,
            y_index=y_index,
        )

Demo: long seasonality (weekly + yearly)

We’ll simulate daily data with two seasonalities:

  • weekly (\(m_1=7\))

  • yearly (\(m_2=365\))

A BATS dummy-season model would need roughly \((7-1)+(365-1)=370\) seasonal parameters (plus trend). TBATS can represent the same seasonal patterns with a small number of harmonics.

def simulate_arma11(n: int, *, phi: float, theta: float, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Simulate ``n`` steps of an ARMA(1, 1) process.

    Recurrence: ``u[t] = phi * u[t-1] + eps[t] + theta * eps[t-1]`` with
    ``eps ~ Normal(0, sigma)`` drawn from ``rng``. The first value has no
    predecessor, so it is just the first shock.
    """
    shocks = rng.normal(0.0, sigma, size=n)
    series = np.zeros(n)
    if n > 0:
        series[0] = shocks[0]
    for step in range(1, n):
        series[step] = phi * series[step - 1] + shocks[step] + theta * shocks[step - 1]
    return series


# Four years of daily observations.
n = 4 * 365
idx = pd.date_range("2018-01-01", periods=n, freq="D")
t = np.arange(n)

# Deterministic components: one harmonic each for the weekly (period 7) and
# yearly (period 365) cycles, plus a slow linear drift.
weekly = 1.0 * np.sin(2 * np.pi * t / 7) + 0.3 * np.cos(2 * np.pi * t / 7)
yearly = 3.5 * np.sin(2 * np.pi * t / 365) + 1.7 * np.cos(2 * np.pi * t / 365)
trend = 0.003 * t

# Autocorrelated ARMA(1, 1) noise drawn from the seeded generator.
noise = simulate_arma11(n, phi=0.5, theta=0.3, sigma=0.9, rng=rng)

# Level 80 + trend + both seasonal cycles + noise, indexed by calendar date.
y = 80.0 + trend + weekly + yearly + noise
y = pd.Series(y, index=idx, name="y")

# Plot the raw simulated series.
fig = go.Figure()
fig.add_trace(go.Scatter(x=y.index, y=y, name="y", line=dict(color="black")))
fig.update_layout(title="Synthetic long-seasonality series", xaxis_title="date", yaxis_title="value")
fig.show()
# Compare parameter counts: BATS dummies vs TBATS Fourier
m1, m2 = 7, 365
# One dummy per season minus one, for each seasonal period.
bats_seasonal_params = (m1 - 1) + (m2 - 1)

# Harmonics per period; each harmonic contributes a cosine and a sine
# coefficient. These two constants are reused when the model is configured.
K_weekly = 3
K_yearly = 10
tbats_seasonal_params = 2 * K_weekly + 2 * K_yearly

print("BATS seasonal parameters (dummies):", bats_seasonal_params)
print("TBATS seasonal parameters (Fourier):", tbats_seasonal_params)
BATS seasonal parameters (dummies): 370
TBATS seasonal parameters (Fourier): 26
# Train/test split + fit TBATS
# Hold out the last h days as the test window.
h = 90
y_train = y.iloc[:-h]
y_test = y.iloc[-h:]

# No Box-Cox, linear trend, both seasonal periods, and a fixed ARMA(1, 1)
# error order (so no AIC search is run).
tbats = TBATS(
    use_box_cox=False,
    box_cox_bounds=(0.0, 1.0),
    use_trend=True,
    use_damped_trend=False,
    seasonal_periods=[7, 365],
    seasonal_harmonics=[K_weekly, K_yearly],
    use_arma_errors=True,
    arma_order=(1, 1),
    show_warnings=True,
)

model = tbats.fit(y_train)
# forecast() is called with its default alpha=0.05.
fcst = model.forecast(h)

# Re-attach date indices: the model returns plain arrays.
fitted = pd.Series(model.fitted_values(), index=y_train.index)
pred_mean = pd.Series(fcst["mean"], index=y_test.index)
pred_lower = pd.Series(fcst["lower"], index=y_test.index)
pred_upper = pd.Series(fcst["upper"], index=y_test.index)

fig = go.Figure()
fig.add_trace(go.Scatter(x=y_train.index, y=y_train, name="train", line=dict(color="rgba(0,0,0,0.35)")))
fig.add_trace(go.Scatter(x=y_train.index, y=fitted, name="fitted", line=dict(color="#59A14F")))
fig.add_trace(go.Scatter(x=y_test.index, y=y_test, name="test", line=dict(color="black")))

# Interval band: the upper bound is drawn first as an invisible line, then the
# lower bound is drawn with fill="tonexty" so the area between them is shaded.
fig.add_trace(go.Scatter(x=y_test.index, y=pred_upper, line=dict(width=0), showlegend=False))
fig.add_trace(
    go.Scatter(
        x=y_test.index,
        y=pred_lower,
        fill="tonexty",
        fillcolor="rgba(89,161,79,0.18)",
        line=dict(width=0),
        name="95% interval (approx)",
    )
)
fig.add_trace(go.Scatter(x=y_test.index, y=pred_mean, name="forecast mean", line=dict(color="#E15759")))

fig.update_layout(title="TBATS forecast on long-seasonality series", xaxis_title="date", yaxis_title="value")
fig.show()
Chosen ARMA(p,q) = (1, 1), AIC = 3512.80
# Residual diagnostics (in transformed space)
resid = model.residuals()
# Drop the first few residuals before computing diagnostics
# (presumably to skip start-up effects of the fit; the cut-off is a judgement call).
warmup = 10
resid_use = resid[warmup:]

print("residual mean:", float(resid_use.mean()))
print("residual std:", float(resid_use.std(ddof=1)))
print("Jarque-Bera:", stats.jarque_bera(resid_use))

# Autocorrelation up to lag 30 with the usual +/- 1.96 / sqrt(n) reference band.
lags, acf_vals = _acf(resid_use, max_lag=30)
bound = 1.96 / np.sqrt(resid_use.size)

# QQ data
# Plotting positions (i - 0.5) / n mapped to standard-normal quantiles, compared
# against the sorted standardized residuals.
nq = resid_use.size
p = (np.arange(1, nq + 1) - 0.5) / nq
theoretical = stats.norm.ppf(p)
sample_q = np.sort((resid_use - resid_use.mean()) / resid_use.std(ddof=1))

# 2x2 panel: time plot, histogram, ACF, QQ plot.
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Residuals over time", "Residual histogram", "Residual ACF", "QQ plot (std residuals)"),
)

fig.add_trace(go.Scatter(x=y_train.index[warmup:], y=resid_use, name="residuals", line=dict(color="#59A14F")), row=1, col=1)
fig.add_hline(y=0, line=dict(color="black", dash="dash"), row=1, col=1)

fig.add_trace(go.Histogram(x=resid_use, nbinsx=30, name="hist", marker_color="#59A14F"), row=1, col=2)

# ACF bars with the upper and lower reference lines.
fig.add_trace(go.Bar(x=lags, y=acf_vals, name="ACF(resid)", marker_color="#59A14F"), row=2, col=1)
fig.add_trace(go.Scatter(x=[0, lags.max()], y=[bound, bound], mode="lines", line=dict(color="gray", dash="dash"), showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=[0, lags.max()], y=[-bound, -bound], mode="lines", line=dict(color="gray", dash="dash"), showlegend=False), row=2, col=1)

# QQ scatter plus the 45-degree reference line.
fig.add_trace(go.Scatter(x=theoretical, y=sample_q, mode="markers", name="QQ", marker=dict(color="#59A14F")), row=2, col=2)
fig.add_trace(
    go.Scatter(x=[theoretical.min(), theoretical.max()], y=[theoretical.min(), theoretical.max()], mode="lines", line=dict(color="black", dash="dash"), showlegend=False),
    row=2,
    col=2,
)

fig.update_layout(height=750, title="TBATS residual diagnostics")
fig.show()
residual mean: 5.158801683041921e-05
residual std: 0.8543250028588507
Jarque-Bera: SignificanceResult(statistic=0.5576490029991557, pvalue=0.7566726864848627)